{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|default_exp models.TSiTPlus"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# TSiT"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is a PyTorch implementation created by Ignacio Oguiza (oguiza@timeseriesAI.co) based on ViT (Vision Transformer):\n",
    "\n",
    "Dosovitskiy, A., Beyer, L., Kolesnikov, A., Weissenborn, D., Zhai, X., Unterthiner, T., ... & Houlsby, N. (2020).\n",
    "\n",
    "<span style=\"color:dodgerblue\">**An image is worth 16x16 words: Transformers for image recognition at scale**</span>. arXiv preprint arXiv:2010.11929."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "from tsai.imports import *\n",
    "from tsai.models.utils import *\n",
    "from tsai.models.layers import *\n",
    "from typing import Callable"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class _TSiTEncoderLayer(nn.Module):\n",
    "    \"\"\"Transformer encoder block: multi-head self-attention followed by a position-wise feed-forward sublayer.\n",
    "\n",
    "    The output of each sublayer goes through `DropPath` (identity when `drop_path_rate` is 0) before it is\n",
    "    added back to the sublayer input. `pre_norm=True` normalizes each sublayer's input, while\n",
    "    `pre_norm=False` normalizes the residual sum instead.\n",
    "    If `lsa` is True and `q_len` is given, a boolean identity mask of shape (1, 1, q_len, q_len) is registered\n",
    "    as the `attn_mask` buffer and handed to the attention layer on every forward pass.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, d_model:int, n_heads:int, q_len:int=None, attn_dropout:float=0., dropout:float=0, drop_path_rate:float=0.,\n",
    "                 mlp_ratio:int=1, lsa:bool=False, qkv_bias:bool=True, act:str='gelu', pre_norm:bool=False):\n",
    "        super().__init__()\n",
    "        self.mha = MultiheadAttention(d_model, n_heads, attn_dropout=attn_dropout, proj_dropout=dropout, lsa=lsa, qkv_bias=qkv_bias)\n",
    "        self.attn_norm = nn.LayerNorm(d_model)\n",
    "        self.pwff = PositionwiseFeedForward(d_model, dropout=dropout, act=act, mlp_ratio=mlp_ratio)\n",
    "        self.ff_norm = nn.LayerNorm(d_model)\n",
    "        self.drop_path = nn.Identity() if drop_path_rate == 0 else DropPath(drop_path_rate)\n",
    "        self.pre_norm = pre_norm\n",
    "\n",
    "        needs_mask = lsa and q_len is not None\n",
    "        if not needs_mask:\n",
    "            self.attn_mask = None\n",
    "        else:\n",
    "            # boolean identity matrix, registered as a buffer so it follows the module across devices\n",
    "            diagonal = torch.eye(q_len).reshape(1, 1, q_len, q_len).bool()\n",
    "            self.register_buffer('attn_mask', diagonal)\n",
    "\n",
    "    def forward(self, x):\n",
    "        # attn_mask is only passed when one was registered; otherwise the attention layer is called without it\n",
    "        mask_kwargs = {} if self.attn_mask is None else {'attn_mask': self.attn_mask}\n",
    "\n",
    "        if not self.pre_norm:\n",
    "            # post-norm: sublayer -> residual add -> LayerNorm\n",
    "            attn_out = self.mha(x, **mask_kwargs)[0]\n",
    "            x = self.attn_norm(self.drop_path(attn_out) + x)\n",
    "            ff_out = self.pwff(x)\n",
    "            return self.ff_norm(self.drop_path(ff_out) + x)\n",
    "\n",
    "        # pre-norm: LayerNorm -> sublayer -> residual add\n",
    "        attn_out = self.mha(self.attn_norm(x), **mask_kwargs)[0]\n",
    "        x = self.drop_path(attn_out) + x\n",
    "        ff_out = self.pwff(self.ff_norm(x))\n",
    "        return self.drop_path(ff_out) + x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Smoke test: run a single encoder layer (d_model=128, n_heads=16) on a random (16, 51, 128) batch\n",
    "xb = torch.rand(16, 51, 128, requires_grad=True).to(default_device())\n",
    "output = _TSiTEncoderLayer(128, 16).to(default_device())(xb)\n",
    "# Then try a simple backward pass\n",
    "target = torch.rand_like(output)\n",
    "loss = ((output - target)**2).mean()\n",
    "loss.backward()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class _TSiTEncoder(nn.Module):\n",
    "    \"\"\"Stack of `depth` `_TSiTEncoderLayer` blocks, followed by a final LayerNorm when `pre_norm=True`.\n",
    "\n",
    "    The drop-path rate of the layers is spaced linearly from 0 (first layer) to `drop_path_rate` (last layer).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, d_model, n_heads, depth:int=6, q_len:int=None, attn_dropout:float=0., dropout:float=0, drop_path_rate:float=0.,\n",
    "                 mlp_ratio:int=1, lsa:bool=False, qkv_bias:bool=True, act:str='gelu', pre_norm:bool=False):\n",
    "        super().__init__()\n",
    "        # one drop-path rate per layer, as plain Python floats\n",
    "        rates = torch.linspace(0, drop_path_rate, depth).tolist()\n",
    "        blocks = [\n",
    "            _TSiTEncoderLayer(d_model, n_heads, q_len=q_len, attn_dropout=attn_dropout, dropout=dropout, drop_path_rate=rate,\n",
    "                              mlp_ratio=mlp_ratio, lsa=lsa, qkv_bias=qkv_bias, act=act, pre_norm=pre_norm)\n",
    "            for rate in rates\n",
    "        ]\n",
    "        self.encoder = nn.Sequential(*blocks)\n",
    "        self.norm = nn.LayerNorm(d_model) if pre_norm else nn.Identity()\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.norm(self.encoder(x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Smoke test: run the full encoder stack (default depth) on a random (16, 51, 128) batch\n",
    "xb = torch.rand(16, 51, 128, requires_grad=True).to(default_device())\n",
    "output = _TSiTEncoder(128, 16).to(default_device())(xb)\n",
    "# Then try a simple backward pass\n",
    "target = torch.rand_like(output)\n",
    "loss = ((output - target)**2).mean()\n",
    "loss.backward()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class _TSiTBackbone(Module):\n",
    "    \"\"\"Backbone of TSiTPlus: turns a (bs, c_in, seq_len) batch into a sequence of d_model-sized steps and encodes it.\n",
    "\n",
    "    Pipeline: categorical embeddings -> tokenizer -> feature extractor -> transpose to (bs, steps, channels)\n",
    "    -> linear projection -> learned position embedding -> optional class token -> dropout -> transformer encoder.\n",
    "    The output is transposed back to channels-first, i.e. (bs, d_model, steps [+ 1 if use_token]).\n",
    "\n",
    "    The linear projection is only created when no token_size, tokenizer or feature_extractor is given; in every\n",
    "    other case the last preprocessing stage has to output d_model channels itself.\n",
    "    \"\"\"\n",
    "    # NOTE(review): no explicit super().__init__() call here - presumably `Module` (from the star imports)\n",
    "    # takes care of nn.Module initialization; confirm before switching the base class to nn.Module.\n",
    "\n",
    "    def __init__(self, c_in:int, seq_len:int, depth:int=6, d_model:int=128, n_heads:int=16, act:str='gelu',\n",
    "                 lsa:bool=False, qkv_bias:bool=True, attn_dropout:float=0., dropout:float=0., drop_path_rate:float=0., mlp_ratio:int=1,\n",
    "                 pre_norm:bool=False, use_token:bool=True,  use_pe:bool=True, n_cat_embeds:Optional[list]=None, cat_embed_dims:Optional[list]=None,\n",
    "                 cat_padding_idxs:Optional[list]=None, cat_pos:Optional[list]=None, feature_extractor:Optional[Callable]=None,\n",
    "                 token_size:int=None, tokenizer:Optional[Callable]=None):\n",
    "\n",
    "        # Categorical embeddings (c_in / seq_len are updated to the shape produced by each stage)\n",
    "        if n_cat_embeds is not None:\n",
    "            n_cat_embeds = listify(n_cat_embeds)\n",
    "            if cat_embed_dims is None:\n",
    "                cat_embed_dims = [emb_sz_rule(s) for s in n_cat_embeds]\n",
    "            self.to_cat_embed = MultiEmbedding(c_in, n_cat_embeds, cat_embed_dims=cat_embed_dims, cat_padding_idxs=cat_padding_idxs, cat_pos=cat_pos)\n",
    "            c_in, seq_len = output_size_calculator(self.to_cat_embed, c_in, seq_len)\n",
    "        else:\n",
    "            self.to_cat_embed = nn.Identity()\n",
    "\n",
    "        # Sequence embedding (token_size takes precedence over a user-supplied tokenizer)\n",
    "        if token_size is not None:\n",
    "            self.tokenizer = SeqTokenizer(c_in, d_model, token_size)\n",
    "            c_in, seq_len = output_size_calculator(self.tokenizer, c_in, seq_len)\n",
    "        elif tokenizer is not None:\n",
    "            if isinstance(tokenizer, nn.Module):  self.tokenizer = tokenizer\n",
    "            else: self.tokenizer = tokenizer(c_in, d_model)\n",
    "            c_in, seq_len = output_size_calculator(self.tokenizer, c_in, seq_len)\n",
    "        else:\n",
    "            self.tokenizer = nn.Identity()\n",
    "\n",
    "        # Feature extractor (either a ready-made module or a callable built as feature_extractor(c_in, d_model))\n",
    "        if feature_extractor is not None:\n",
    "            if isinstance(feature_extractor, nn.Module):  self.feature_extractor = feature_extractor\n",
    "            else: self.feature_extractor = feature_extractor(c_in, d_model)\n",
    "            c_in, seq_len = output_size_calculator(self.feature_extractor, c_in, seq_len)\n",
    "        else:\n",
    "            self.feature_extractor = nn.Identity()\n",
    "\n",
    "        # Linear projection (only when none of the stages above is in charge of producing d_model channels)\n",
    "        if token_size is None and tokenizer is None and feature_extractor is None:\n",
    "            # self.linear_proj = nn.Conv1d(c_in, d_model, 1)\n",
    "            self.linear_proj = nn.Linear(c_in, d_model)\n",
    "        else:\n",
    "            self.linear_proj = nn.Identity()\n",
    "\n",
    "        self.transpose = Transpose(1,2)\n",
    "\n",
    "        # Position embedding & token\n",
    "        if use_pe:\n",
    "            self.pos_embed = nn.Parameter(torch.zeros(1, seq_len, d_model))\n",
    "        self.use_pe = use_pe\n",
    "        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))\n",
    "        self.use_token = use_token\n",
    "        self.emb_dropout = nn.Dropout(dropout)\n",
    "\n",
    "        # Encoder\n",
    "        # attn_dropout has to be forwarded explicitly; otherwise the attention layers silently keep their default of 0.\n",
    "        # q_len counts the extra class token (use_token is a bool, so it adds 0 or 1).\n",
    "        self.encoder = _TSiTEncoder(d_model, n_heads, depth=depth, q_len=seq_len + use_token, qkv_bias=qkv_bias, lsa=lsa,\n",
    "                                    attn_dropout=attn_dropout, dropout=dropout, mlp_ratio=mlp_ratio, drop_path_rate=drop_path_rate,\n",
    "                                    act=act, pre_norm=pre_norm)\n",
    "\n",
    "    def forward(self, x):\n",
    "\n",
    "        # Categorical embeddings\n",
    "        x = self.to_cat_embed(x)\n",
    "\n",
    "        # Sequence embedding\n",
    "        x = self.tokenizer(x)\n",
    "\n",
    "        # Feature extractor\n",
    "        x = self.feature_extractor(x)\n",
    "\n",
    "        # Linear projection (applied on the last dim, hence the transpose to channels-last first)\n",
    "        x = self.transpose(x)\n",
    "        x = self.linear_proj(x)\n",
    "\n",
    "        # Position embedding & token\n",
    "        if self.use_pe:\n",
    "            x = x + self.pos_embed\n",
    "        if self.use_token: # token is concatenated after position embedding so that embedding can be learned using self-supervised learning\n",
    "            x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)\n",
    "        x = self.emb_dropout(x)\n",
    "\n",
    "        # Encoder\n",
    "        x = self.encoder(x)\n",
    "\n",
    "        # Output (back to channels-first)\n",
    "        x = x.transpose(1,2)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class TSiTPlus(nn.Sequential):\n",
    "    r\"\"\"Time series transformer model based on ViT (Vision Transformer):\n",
    "\n",
    "    Dosovitskiy, A., Beyer, L., Kolesnikov, A., Weissenborn, D., Zhai, X., Unterthiner, T., ... & Houlsby, N. (2020).\n",
    "    An image is worth 16x16 words: Transformers for image recognition at scale. arXiv preprint arXiv:2010.11929.\n",
    "\n",
    "    This implementation is a modified version of Vision Transformer that is part of the great timm library\n",
    "    (https://github.com/rwightman/pytorch-image-models/blob/72b227dcf57c0c62291673b96bdc06576bb90457/timm/models/vision_transformer.py)\n",
    "\n",
    "    Args:\n",
    "        c_in:               the number of features (aka variables, dimensions, channels) in the time series dataset.\n",
    "        c_out:              the number of target classes.\n",
    "        seq_len:            number of time steps in the time series.\n",
    "        d_model:            total dimension of the model (number of features created by the model).\n",
    "        depth:              number of blocks in the encoder.\n",
    "        n_heads:            parallel attention heads. Default:16 (range(8-16)).\n",
    "        act:                the activation function of positionwise feedforward layer.\n",
    "        lsa:                locality self attention used (see Lee, S. H., Lee, S., & Song, B. C. (2021). Vision Transformer for Small-Size Datasets.\n",
    "                            arXiv preprint arXiv:2112.13492.)\n",
    "        attn_dropout:       dropout rate applied to the attention sublayer.\n",
    "        dropout:            dropout applied to the embedded sequence steps after position embeddings have been added and\n",
    "                            to the mlp sublayer in the encoder.\n",
    "        drop_path_rate:     stochastic depth rate.\n",
    "        mlp_ratio:          ratio of mlp hidden dim to embedding dim.\n",
    "        qkv_bias:           determines whether bias is applied to the Linear projections of queries, keys and values in the MultiheadAttention\n",
    "        pre_norm:           if True normalization will be applied as the first step in the sublayers. Defaults to False.\n",
    "        use_token:          if True, the output will come from the transformed token. This is meant to be used in classification tasks.\n",
    "                            It is forced to False when c_out == 1.\n",
    "        use_pe:             flag to indicate if positional embedding is used.\n",
    "        n_cat_embeds:       list with the sizes of the dictionaries of embeddings (int).\n",
    "        cat_embed_dims:     list with the sizes of each embedding vector (int).\n",
    "        cat_padding_idxs:   If specified, the entries at cat_padding_idxs do not contribute to the gradient; therefore, the embedding vector at cat_padding_idxs\n",
    "                            are not updated during training. Use 0 for those categorical embeddings that may have #na# values. Otherwise, leave them as None.\n",
    "                            You can enter a combination for different embeddings (for example, [0, None, None]).\n",
    "        cat_pos:            list with the position of the categorical variables in the input.\n",
    "        token_size:         Size of the embedding function used to reduce the sequence length (similar to ViT's patch size)\n",
    "        tokenizer:          nn.Module or callable that will be used to reduce the sequence length\n",
    "        feature_extractor:  nn.Module or callable that will be used to preprocess the time series before\n",
    "                            the embedding step. It is useful to extract features or resample the time series.\n",
    "        flatten:            flag to indicate if the 3d logits will be flattened to 2d in the model's head if use_token is set to False.\n",
    "                            If use_token is False and flatten is False, the model will apply a pooling layer.\n",
    "        concat_pool:        if True the pooling head uses GACP1d (concat pooling, which doubles the number of features);\n",
    "                            otherwise, it uses GAP1d (average pooling).\n",
    "        fc_dropout:         dropout applied to the final fully connected layer.\n",
    "        use_bn:             flag that indicates if batchnorm will be applied to the head.\n",
    "        bias_init:          value(s) used to initialize the bias of the output layer: a scalar (applied to every output)\n",
    "                            or an array-like with one value per output.\n",
    "        y_range:            range of possible y values (used in regression tasks).\n",
    "        custom_head:        custom head that will be applied to the network. It must contain all kwargs (pass a partial function)\n",
    "        verbose:            flag to control verbosity of the model.\n",
    "        kwargs:             extra keyword arguments passed to custom_head when it is a callable (not an nn.Module).\n",
    "\n",
    "    Input:\n",
    "        x: bs (batch size) x nvars (aka features, variables, dimensions, channels) x seq_len (aka time steps)\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, c_in:int, c_out:int, seq_len:int, d_model:int=128, depth:int=6, n_heads:int=16, act:str='gelu',\n",
    "                 lsa:bool=False, attn_dropout:float=0., dropout:float=0., drop_path_rate:float=0., mlp_ratio:int=1, qkv_bias:bool=True,\n",
    "                 pre_norm:bool=False, use_token:bool=False, use_pe:bool=True,\n",
    "                 cat_pos:Optional[list]=None, n_cat_embeds:Optional[list]=None, cat_embed_dims:Optional[list]=None, cat_padding_idxs:Optional[list]=None,\n",
    "                 token_size:int=None, tokenizer:Optional[Callable]=None, feature_extractor:Optional[Callable]=None,\n",
    "                 flatten:bool=False, concat_pool:bool=True, fc_dropout:float=0., use_bn:bool=False,\n",
    "                 bias_init:Optional[Union[float, list]]=None, y_range:Optional[tuple]=None, custom_head:Optional[Callable]=None, verbose:bool=True, **kwargs):\n",
    "\n",
    "        if use_token and c_out == 1:\n",
    "            use_token = False\n",
    "            pv(\"use_token set to False as c_out == 1\", verbose)\n",
    "        # qkv_bias is forwarded explicitly; otherwise the backbone would always fall back to its own default\n",
    "        backbone = _TSiTBackbone(c_in, seq_len, depth=depth, d_model=d_model, n_heads=n_heads, act=act,\n",
    "                                 lsa=lsa, qkv_bias=qkv_bias, attn_dropout=attn_dropout, dropout=dropout, drop_path_rate=drop_path_rate,\n",
    "                                 pre_norm=pre_norm, mlp_ratio=mlp_ratio, use_pe=use_pe, use_token=use_token,\n",
    "                                 n_cat_embeds=n_cat_embeds, cat_embed_dims=cat_embed_dims, cat_padding_idxs=cat_padding_idxs, cat_pos=cat_pos,\n",
    "                                 feature_extractor=feature_extractor, token_size=token_size, tokenizer=tokenizer)\n",
    "\n",
    "        self.head_nf = d_model\n",
    "        self.c_out = c_out\n",
    "        self.seq_len = seq_len\n",
    "\n",
    "        # Head\n",
    "        if custom_head:\n",
    "            if isinstance(custom_head, nn.Module): head = custom_head\n",
    "            else:\n",
    "                head = custom_head(self.head_nf, c_out, seq_len, **kwargs)\n",
    "        else:\n",
    "            nf = d_model\n",
    "            layers = []\n",
    "            if use_token:\n",
    "                layers += [TokenLayer()]\n",
    "            elif flatten:\n",
    "                layers += [Reshape(-1)]\n",
    "                # NOTE(review): uses the input seq_len; if a tokenizer / feature_extractor shortens the sequence,\n",
    "                # this may not match the number of steps produced by the backbone - confirm.\n",
    "                nf = nf * seq_len\n",
    "            else:\n",
    "                if concat_pool: nf *= 2\n",
    "                layers += [GACP1d(1) if concat_pool else GAP1d(1)]\n",
    "            if use_bn: layers += [nn.BatchNorm1d(nf)]\n",
    "            if fc_dropout: layers += [nn.Dropout(fc_dropout)]\n",
    "\n",
    "            # Last layer\n",
    "            linear = nn.Linear(nf, c_out)\n",
    "            if bias_init is not None:\n",
    "                # Scalars (int or float) fill the existing bias in place; treating an int as array-like would\n",
    "                # replace the bias with a 0-dim parameter instead of one value per output.\n",
    "                if isinstance(bias_init, (int, float)): nn.init.constant_(linear.bias, bias_init)\n",
    "                else: linear.bias = nn.Parameter(torch.as_tensor(bias_init, dtype=torch.float32))\n",
    "            layers += [linear]\n",
    "\n",
    "            if y_range: layers += [SigmoidRange(*y_range)]\n",
    "            head = nn.Sequential(*layers)\n",
    "        super().__init__(OrderedDict([('backbone', backbone), ('head', head)]))\n",
    "\n",
    "\n",
    "TSiT = TSiTPlus"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([16, 2])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# End-to-end check with the class token head: forward and backward pass on a random batch\n",
    "bs = 16\n",
    "nvars = 4\n",
    "seq_len = 50\n",
    "c_out = 2\n",
    "\n",
    "xb = torch.rand(bs, nvars, seq_len, requires_grad=True).to(default_device())\n",
    "model = TSiTPlus(nvars, c_out, seq_len, attn_dropout=.1, dropout=.1, use_token=True)\n",
    "output = model.to(default_device())(xb)\n",
    "target = torch.rand_like(output)\n",
    "loss = ((output - target)**2).mean()\n",
    "loss.backward()\n",
    "output.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([16, 2])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# The output shape must be (bs, c_out) both with and without the class token\n",
    "bs = 16\n",
    "nvars = 4\n",
    "seq_len = 50\n",
    "c_out = 2\n",
    "xb = torch.rand(bs, nvars, seq_len)\n",
    "model = TSiTPlus(nvars, c_out, seq_len, attn_dropout=.1, dropout=.1, use_token=True)\n",
    "test_eq(model(xb).shape, (bs, c_out))\n",
    "model = TSiTPlus(nvars, c_out, seq_len, attn_dropout=.1, dropout=.1, use_token=False)\n",
    "test_eq(model(xb).shape, (bs, c_out))\n",
    "model(xb).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# An array-like bias_init (one value per output) must end up as the bias of the head's linear layer (head[1])\n",
    "bs = 16\n",
    "nvars = 4\n",
    "seq_len = 50\n",
    "c_out = 2\n",
    "xb = torch.rand(bs, nvars, seq_len)\n",
    "bias_init = np.array([0.8, .2])\n",
    "model = TSiTPlus(nvars, c_out, seq_len, bias_init=bias_init)\n",
    "test_eq(model(xb).shape, (bs, c_out))\n",
    "test_eq(model.head[1].bias.data, tensor(bias_init))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A scalar bias_init with a single output (c_out=1) must fill the bias of the head's linear layer\n",
    "bs = 16\n",
    "nvars = 4\n",
    "seq_len = 50\n",
    "c_out = 1\n",
    "xb = torch.rand(bs, nvars, seq_len)\n",
    "bias_init = 8.5\n",
    "model = TSiTPlus(nvars, c_out, seq_len, bias_init=bias_init)\n",
    "test_eq(model(xb).shape, (bs, c_out))\n",
    "test_eq(model.head[1].bias.data, tensor([bias_init]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same bias_init check with locality self-attention enabled (lsa=True)\n",
    "bs = 16\n",
    "nvars = 4\n",
    "seq_len = 50\n",
    "c_out = 2\n",
    "xb = torch.rand(bs, nvars, seq_len)\n",
    "bias_init = np.array([0.8, .2])\n",
    "model = TSiTPlus(nvars, c_out, seq_len, bias_init=bias_init, lsa=True)\n",
    "test_eq(model(xb).shape, (bs, c_out))\n",
    "test_eq(model.head[1].bias.data, tensor(bias_init))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tsai.models.layers import lin_nd_head"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A callable custom_head receives the extra keyword arguments (here `d`) given to TSiTPlus\n",
    "bs = 16\n",
    "nvars = 4\n",
    "seq_len = 50\n",
    "c_out = 2\n",
    "d = 7\n",
    "xb = torch.rand(bs, nvars, seq_len)\n",
    "# pass the `d` variable (not a repeated literal) so the expected shape below cannot drift from the model config\n",
    "model = TSiTPlus(nvars, c_out, seq_len, d=d, custom_head=lin_nd_head)\n",
    "test_eq(model(xb).shape, (bs, d, c_out))\n",
    "\n",
    "xb = torch.rand(bs, nvars, seq_len, requires_grad=True)\n",
    "output = model(xb)\n",
    "# Then try a simple backward pass\n",
    "target = torch.rand_like(output)\n",
    "loss = ((output - target)**2).mean()\n",
    "loss.backward()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Feature extractor\n",
    "\n",
    "The memory and compute cost of self-attention grows quadratically with the sequence length, so transformers cannot be directly applied to very long sequences. To work around this, we have included a way to subsample the sequence to generate a more manageable input."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tsai.data.validation import get_splits\n",
    "from tsai.data.core import get_ts_dls"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAABZcAAABoCAYAAACNDM73AAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAGDpJREFUeJzt3QtwVPXZx/EnNwhsLhAugUAgFIMiVwlIQYowCiKVShQpii1gi1qUtE2xNY4moEhUlKFcolBa0qlQUSuYKliBCgG8cEcpELkTKxWqQEgUAmHfef7TzZvAkuQku3t2N9/PzM6ym91znj3xuMkvzz7/EKfT6RQAAAAAAAAAACwItfJgAAAAAAAAAAAU4TIAAAAAAAAAwDLCZQAAAAAAAACAZYTLAAAAAAAAAADLCJcBAAAAAAAAAJYRLgMAAAAAAAAALCNcBgAAAAAAAABYRrgMAAAAAAAAALCMcBkAAAAAAAAAYBnhMgAAgJfk5uZKSEiIHDlypPy+QYMGmYunTZ061eyroqSkJBk/frx4m74+3be+Xhfdb1RUlPiK7l+PAQAAAADfIVwGAAD4n88++0xGjRol7du3l8jISGnTpo0MGTJE5s6d67V9fvnllyYU3blzp/iDlStX+m1I68+1AQAAAPVRuN0FAAAA+IMPP/xQBg8eLO3atZOJEydKq1atpLCwUD7++GP5/e9/L5MnT/bIft5///0rwuVp06aZLuOePXuKJxUUFEhoaKjlAHf+/PmWQlwN47/77juJiIioRZWeqU33Hx7Oj7YAAACAL/ETOAAAgIg8++yzEhsbK1u2bJEmTZpU+tqJEyc8tp8GDRqIrzRs2NCr27948aJcunTJvCbt9LaT3fsHAAAA6iPGYgAAAIjIwYMHpUuXLlcEy6ply5ZXzPd99NFHZcmSJXLttdeaYDMlJUXy8/Or3U/Fmcvr1q2TPn36mH9PmDDBbPfy2cXubNy40TxP99uxY0dZsGCB28ddPnP5woULpks6OTnZPLdZs2YyYMAAWb16tfm6PlY7g12v0XWpOFf5xRdflNmzZ5v9ani9Z88etzOXXQ4dOiS33XabOBwOSUhIkKefflqcTmf51/UY6HP1uqLLt1lVba77Lu9o3rFjh9x+++0SExNj5j/fcsstphPd3VzsTZs2SXp6urRo0cLUmpqaKidPnqzy+wAAAADUd3QuAwAA/G+0w0cffSS7d++Wrl27Vvv49evXy7JlyyQtLc2ErDk5OTJs2DDZvHlzjZ6vOnfubMLWzMxMefDBB+UHP/iBub9///5VzoUeOnSoCUE1TNXu4aysLImPj692f/r47Oxs+fnPfy433nijFBUVydatW2X79u1mtvRDDz1kxnRo2PyXv/zF7TYWL14s586dM/Xq646LizPdy+6UlZWZY/L9739fXnjhBXnvvfdMrVqzvm4ralJbRf/617/M8dRg+be//a0Z2aEhvAb7+r3r27dvpcfr2JOmTZua+jTY1gBd/4Cg32MAAAAA7hEuAwAAiMiUKVNMl6vOPdbgVYNJ7XTVOczuZglrCK3BrHYsqzFjxpguZg2K33rrrRrtUwNh3ac+p1+/fnL//fdX+xx9rHb+btiwwcyHVnfffbd069at2ue+++67Mnz4cFm4cKHbr2sNnTp1MgHu1Wr54osv5MCBAybcdtEw1h0NoTVcnjNnjrk9adIkGTFihDz//PMmlG/evHm1NVupraInn3zSdGprl/f3vvc9c99Pf/pT8z3SsFkD5oq0i1vnYbu6oTUw17rPnDljxqUAAAAAuBJjMQAAAERM5652Lv/oRz+SXbt2mU5bHefQpk0bycvLcxt2uoJlpUHvnXfeKf/4xz9Mx6436HZ1+yNHjiwPll0d0FprdXTkh3b07t+/v9Y1aJBdMViujnb/Xj5OpLS0VNasWSPeosdJg2I9Tq5gWbVu3Vruu+8+Ezhr13ZF2oldccyG/nFBt3P06FGv
1QkAAAAEOsJlAACA/9E5xtp1fOrUKTPeIiMjQ86ePSujRo0ys4Ur0rnFl9PO2m+//dZrs3p1u999953bfWtHbnV0FMXp06dNndrp/Nhjj8mnn35qqYYOHTrU+LGhoaGVwl2l+66q29lTx0m/D+6OiQbx2pVcWFhY6f6KYb3SERlK/1sAAAAA4B7hMgAAwGUaNGhgguYZM2bIyy+/bMYrvPHGGxLoBg4caBYu/NOf/mTmQi9atEh69eplrmuqUaNGHq2pYrdwRd7q/r6asLAwt/dXXHwQAAAAQGWEywAAAFXo3bu3uT5+/Hil+92Nlvj888+lcePGlsZGXC1cdUe3q+Guu30XFBTUaBu6AN+ECRPkr3/9q+ne7d69u1norzb1VEc7hA8dOnTFMVJJSUmVOoS1o7oid+MoalqbHif9Prg7Jvv27TMd1YmJiRZeCQAAAAB3CJcBAABE5IMPPnDbpbpy5UpzffmIBZ3PvH379vLbGtS+/fbbMnTo0Kt2wbrjcDjchqvu6HZ1tvKKFSvk2LFj5ffv3bvXzGKuztdff13pdlRUlFxzzTVy/vz5WtVTE/PmzSv/tx5fva0LJOpiiap9+/bmdeXn51d6Xk5OzhXbqmltuj39Puj3o+L4ja+++kqWLl0qAwYMkJiYmDq/NgAAAKC+C7e7AAAAAH8wefJkM6c3NTVVrrvuOrPo3IcffijLli0zXbba7VuRjpXQoDctLU0aNmxYHoZOmzbN0n47duxoFtp75ZVXJDo62gSoffv2vepsY93+e++9ZxacmzRpkly8eFHmzp0rXbp0qXZ+8vXXXy+DBg0yCxFqB/PWrVvlzTffrLTonmuRQn1d+vo0qB0zZozURmRkpKl13Lhx5jWtWrVK3n33XXniiSfKu7tjY2PlnnvuMa9BO5P1eLzzzjty4sSJK7Znpbbp06fL6tWrTZCsxyk8PFwWLFhggnRdrBEAAABA3REuAwAAiMiLL75o5iprp/LChQtNuKyLvGkw+eSTT5oAuKKbb75Z+vXrZ8Je7SLW4DY3N9eMmbBCu3j//Oc/m8UDH374YRMWL168+Krhsm5fu5TT09MlMzNT2rZta2rQsR3Vhcsayubl5cn7779vQlbtGtYQVhf2c7nrrrtM0P7aa6/Jq6++arqNaxsua/ir4fIvfvELsw8Nz7OyskzdFWmwrHOtNWDXoH706NEyc+ZME+BXZKU2Dds3bNhgjmt2drYZ0aEBtz5PrwEAAADUXYiTVUoAAAAs0Q7bRx55pNLIBwAAAACob5i5DAAAAAAAAACwjHAZAAAAAAAAAGAZ4TIAAAAAAAAAwDIW9AMAALCIJSsAAAAAgM5lAAAAAAAAAEAtEC4DAAAAAAAAAPx/LMalS5fkyy+/lOjoaAkJCfH17gEAAAAAAICAH9N29uxZSUhIkNBQekdRj8JlDZYTExN9vVsAAAAAAAAgqBQWFkrbtm3tLgP1mM/DZe1YVl1XdpUwR5ivdw8b7Lo53+4SAAAe0GP9QLtLgA/x/g0AAODPikQksTxnA+pNuOwahaHBclgU4XL9EGN3AQAAD+B9u77h/RsAAMDfMXIWdmMoCwAAAAAAAADAMsJlAAAAAAAAAIBlhMsAAAAAAAAAAP+fuQwAAAAAAAAA3lBWViYXLlywu4yAFRYWJuHh4TWe5024DAAAAAAAACDgFRcXyxdffCFOp9PuUgJa48aNpXXr1tKgQYNqH0u4DAAAAAAAACDgO5Y1WNZgtEWLFjXuvMX/01C+tLRUTp48KYcPH5bk5GQJDa16qjLhMgAAAAAAAICApqMwNBzVYLlRo0Z2lxOw9NhFRETI0aNHTdAcGRlZ5eNZ0A8AAAAAAABAUKBjue6q61au9FgP7A8AAAAAAAAAUM8QLgMAAAAAAAAALCNcBgAAAAAAAIAgkZSUJLNnz/bJvgiXAQAAAAAAAAQlHcHsy4vV+dBVXaZOnSq1sWXLFnnwwQfFL8Pl/Px8GTFihCQkJJgX
uWLFCu9UBgAAAAAAAABB6vjx4+UX7TSOiYmpdN+UKVPKH+t0OuXixYs12m6LFi2kcePG4pfhcklJifTo0UPmz5/vnYoAAAAAAAAAIMi1atWq/BIbG2saeV239+3bJ9HR0bJq1SpJSUmRhg0bysaNG+XgwYNy5513Snx8vERFRUmfPn1kzZo1VY7F0O0uWrRIUlNTTeicnJwseXl59oTLt99+u0yfPt0UAwAAAAAAAADwjscff1yee+452bt3r3Tv3l2Ki4tl+PDhsnbtWtmxY4cMGzbMTJk4duxYlduZNm2ajB49Wj799FPz/LFjx8o333zj/zOXz58/L0VFRZUuAAAAAAAAAICqPf300zJkyBDp2LGjxMXFmYkSDz30kHTt2tV0ID/zzDPma9V1Io8fP17uvfdeueaaa2TGjBkmpN68ebP4fbicnZ1t2rpdl8TERG/vEgAAAAAAAAACXu/evSvd1lBYZzF37txZmjRpYkZjaFdzdZ3L2vXs4nA4zHznEydO+H+4nJGRIWfOnCm/FBYWenuXAAAAAAAAABDwHA5HpdsaLC9fvtx0H2/YsEF27twp3bp1k9LS0iq3ExERUem2zmG+dOlSnesLFy/TYdN6AQAAAAAAAADU3qZNm8yIC9d6eNrJfOTIEbGL1zuXAQAAAAAAAAB1p3OW33rrLdOxvGvXLrnvvvs80oHss85lTcMPHDhQfvvw4cPmxehA6Xbt2nm6PgAAAAAAAACoFadTgsqsWbPkgQcekP79+0vz5s3ld7/7nRQVFdlWT4jTae0Qr1u3TgYPHnzF/ePGjZPc3Nxqn68vVhf267G+h4RFhVmrFgFpe8o2u0sAAHhAr20pdpcAH+L9GwAAwJ9pmBhr1jfThdkgcu7cOdME26FDB4mMjLS7nHpzLC13Lg8aNEgs5tEAAAAAAAAAgCDDzGUAAAAAAAAAgGWEywAAAAAAAAAAywiXAQAAAAAAAACWES4DAAAAAAAAACwjXAYAAAAAAAAAWEa4DAAAAAAAAACwjHAZAAAAAAAAAGAZ4TIAAAAAAAAAwDLCZQAAAAAAAACAZeHWnwIAAAAAAAAA/i9le4pP97et17YaPzYkJKTKr2dlZcnUqVNrVYdue/ny5TJy5EjxJsJlAAAAAAAAAPCx48ePl/972bJlkpmZKQUFBeX3RUVFib/zebjsdDrNdVlJma93DdsU2V0AAMADyop5765feP8GAADw95/VXDkbAlOrVq3K/x0bG2u6jSvet2jRInnppZfk8OHDkpSUJGlpaTJp0iTztdLSUklPT5e//e1vcurUKYmPj5eHH35YMjIyzGNVamqquW7fvr0cOXIkOMLlr7/+2lzvHr7b17uGbWLtLgAA4AG7bra7AvgW798AAAD+TnM2DSURfJYsWWI6mefNmyc33HCD7NixQyZOnCgOh0PGjRsnc+bMkby8PHn99delXbt2UlhYaC5qy5Yt0rJlS1m8eLEMGzZMwsLCvFanz8PluLg4c33s2DH+4weCTFFRkSQmJpr/mcXExNhdDgAP4vwGghfnNxC8OL+B4HXmzBkTKLpyNgSfrKws07V81113mdsdOnSQPXv2yIIFC0y4rNlqcnKyDBgwwHQ8a3eyS4sWLcx1kyZNKnVCB0W4HBoaaq41WObNDQhOem5zfgPBifMbCF6c30Dw4vwGgpcrZ0NwKSkpkYMHD8rPfvYz063scvHixfJm3fHjx8uQIUPk2muvNd3Jd9xxhwwdOtTntbKgHwAAAAAAAAD4ieLiYnP9hz/8Qfr27Vvpa64RF7169TKzmFetWiVr1qyR0aNHy6233ipvvvmmT2slXAYAAAAAAAAAPxEfHy8JCQly6NAhGTt27FUfp59K+fGPf2wuo0aNMh3M33zzjRmXEhERIWVlZcEXLjds2NDMDNFrAMGF8xsIXpzfQPDi/AaCF+c3ELw4v4PftGnTJC0tzYzB0ND4/PnzsnXrVjl16pSkp6fLrFmzpHXr1max
Px2P8sYbb5j5yjpnWSUlJcnatWvlpptuMv+dNG3a1Ct1hjidTqdXtgwAAAAAAAAAPnDu3DkzJkIXvouMjJRAk5ubK7/61a/k9OnT5fctXbpUZs6caRbyczgc0q1bN/OY1NRUMzIjJydH9u/fb0Zl9OnTxzxWw2b197//3YTQR44ckTZt2phrbxxLwmUAAAAAAAAAAS3Qw+VAPZYsKQkAAAAAAAAAsIxwGQAAAAAAAABgGeEyAAAAAAAAAMAywmUAAAAAAAAAgH+Hy/Pnz5ekpCQzCLpv376yefNmX+4egBdkZ2ebFUmjo6OlZcuWMnLkSCkoKLC7LABe8Nxzz0lISIhZnRhA4Pv3v/8t999/vzRr1kwaNWpkVh/funWr3WUBqKOysjJ56qmnzCJMem537NhRnnnmGXE6nXaXBsCi/Px8GTFihCQkJJifw1esWFHp63peZ2ZmSuvWrc35fuutt8r+/fulvuP/d749hj4Ll5ctWybp6emSlZUl27dvlx49eshtt90mJ06c8FUJALxg/fr18sgjj8jHH38sq1evlgsXLsjQoUOlpKTE7tIAeNCWLVtkwYIF0r17d7tLAeABp06dkptuukkiIiJk1apVsmfPHnnppZekadOmdpcGoI6ef/55efnll2XevHmyd+9ec/uFF16QuXPn2l0aAIv092rNz7RZ0x09t+fMmSOvvPKKfPLJJ+JwOEzWdu7cOamPwsLCzHVpaandpQS8b7/91lzrz4rVCXH6KM7XTmXtbtQ3OHXp0iVJTEyUyZMny+OPP+6LEgD4wMmTJ00Hs4bOAwcOtLscAB5QXFwsvXr1kpycHJk+fbr07NlTZs+ebXdZAOpAf/7etGmTbNiwwe5SAHjYHXfcIfHx8fLHP/6x/L67777bdDW++uqrttYGoPa0c3n58uXm08JK4zztaP7Nb34jU6ZMMfedOXPGnP+5ubkyZswYqW/0mBw7dsw0vemxCQ1lGnBtjqEGy9oM3KRJE9MVX51w8QH9i8G2bdskIyOj/D79Bmu7/kcffeSLEgD4iL6Zqbi4OLtLAeAh+umEH/7wh+Z9W8NlAIEvLy/PdDbdc8895g/Cbdq0kUmTJsnEiRPtLg1AHfXv318WLlwon3/+uXTq1El27dolGzdulFmzZtldGgAPOnz4sPznP/8xP6O7xMbGmuZOzdrqY7isAbyGoXpsjh49anc5AU2D5VatWtXosT4Jl//73/+auU/615OK9Pa+fft8UQIAH9BPJOgsVv2YbdeuXe0uB4AHvPbaa2aclY7FABA8Dh06ZD42r2PrnnjiCXOOp6WlSYMGDWTcuHF2lwegjp9MKCoqkuuuu858RFx/F3/22Wdl7NixdpcGwIM0WFbusjbX1+oj/VkmOTmZ0Rh1oKMwXCNG/CZcBlB/uht3795tOiMABL7CwkL55S9/aeap62K8AILrD8K9e/eWGTNmmNs33HCDeQ/XmY2Ey0Bge/3112XJkiWydOlS6dKli+zcudM0gOhHxDm/AdQHOi2B3198xyfDR5o3b24S76+++qrS/Xq7pi3WAPzbo48+Ku+884588MEH0rZtW7vLAeABOtJKZ23pvOXw8HBz0Y/P66Ih+m/thAIQmPQjo9dff32l+zp37mzmFAIIbI899pjpXtaPxHfr1k1+8pOfyK9//WvJzs62uzQAHuTK08jaUC/CZW1JT0lJkbVr11bqltDb/fr180UJALw47F2DZV1Y4J///Kd06NDB7pIAeMgtt9win332mel4cl2001E/Vqv/tvJRKQD+RUdYFRQUVLpP57O2b9/etpoAeIYuxHT5Ilb6nq2/gwMIHvq7t4bIFbM2HYnzySefkLXBp3w2FkPnuelHcPSX0htvvNGsMl9SUiITJkzwVQkAvDQKQz9y9/bbb0t0dHT5bCddSEBXpAYQuPScvnx+usPhkGbNmjFXHQhw2sWoi37pWIzRo0fL5s2bzQJgegEQ2EaMGGFmLLdr186Mxdix
Y4dZzO+BBx6wuzQAFhUXF8uBAwfKb+tCddrkERcXZ85xHXmjC27rjGENm5966ikzAmfkyJG21o36JcSpbYc+Mm/ePJk5c6YJn3r27Gk+VqurWAII7NVY3Vm8eLGMHz/e5/UA8K5BgwaZ93D9IzGAwKbjrDIyMmT//v3mF1JtBpk4caLdZQGoo7Nnz5qAST9ZqOOtNGi69957JTMz03yqGEDgWLdunQwePPiK+7V5Mzc313ySOCsry/xx+PTp0zJgwADJycmRTp062VIv6iefhssAAAAAAAAAgODgk5nLAAAAAAAAAIDgQrgMAAAAAAAAALCMcBkAAAAAAAAAYBnhMgAAAAAAAADAMsJlAAAAAAAAAIBlhMsAAAAAAAAAAMsIlwEAAAAAAAAAlhEuAwAAAAAAAAAsI1wGAAAAAAAAAFhGuAwAAAAAAAAAsIxwGQAAAAAAAAAgVv0frqLLft4nwY8AAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 1600x50 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "TSTensor(samples:8, vars:3, len:5000, device=mps:0, dtype=torch.float32)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Dummy dataset with long sequences: 10 samples, 3 variables, 5000 time steps, and random binary labels\n",
    "X = np.zeros((10, 3, 5000))\n",
    "y = np.random.randint(0,2,X.shape[0])\n",
    "splits = get_splits(y)\n",
    "dls = get_ts_dls(X, y, splits=splits)\n",
    "# grab a single training batch to try the feature extractors on\n",
    "xb, yb = dls.train.one_batch()\n",
    "xb"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you try to use TSiTPlus, it's likely you'll get an 'out-of-memory' error.\n",
    "\n",
    "To avoid this you can subsample the sequence reducing the input's length. This can be done in multiple ways. Here are a few examples: "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([8, 3, 99])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Separable convolution (to avoid mixing channels)\n",
    "feature_extractor = Conv1d(xb.shape[1], xb.shape[1], ks=100, stride=50, padding=0, groups=xb.shape[1]).to(default_device())\n",
    "feature_extractor.to(xb.device)(xb).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convolution (if you want to mix channels or change number of channels)\n",
    "feature_extractor=MultiConv1d(xb.shape[1], 64, kss=[1,3,5,7,9], keep_original=True).to(default_device())\n",
    "test_eq(feature_extractor.to(xb.device)(xb).shape, (xb.shape[0], 64, xb.shape[-1]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([8, 3, 100])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# MaxPool\n",
    "feature_extractor = nn.Sequential(Pad1d((0, 50), 0), nn.MaxPool1d(kernel_size=100, stride=50)).to(default_device())\n",
    "feature_extractor.to(xb.device)(xb).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([8, 3, 100])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# AvgPool\n",
    "feature_extractor = nn.Sequential(Pad1d((0, 50), 0), nn.AvgPool1d(kernel_size=100, stride=50)).to(default_device())\n",
    "feature_extractor.to(xb.device)(xb).shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once you decide what type of transform you want to apply, you just need to pass the layer as the `feature_extractor` argument:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs = 16\n",
    "nvars = 4\n",
    "seq_len = 1000\n",
    "c_out = 2\n",
    "d_model = 128\n",
    "\n",
    "xb = torch.rand(bs, nvars, seq_len)\n",
    "feature_extractor = partial(Conv1d, ks=5, stride=3, padding=0, groups=xb.shape[1])\n",
    "model = TSiTPlus(nvars, c_out, seq_len, d_model=d_model, feature_extractor=feature_extractor)\n",
    "test_eq(model.to(xb.device)(xb).shape, (bs, c_out))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Categorical variables"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tsai.utils import alphabet, ALPHABET"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "a = alphabet[np.random.randint(0,3,40)]\n",
    "b = ALPHABET[np.random.randint(6,10,40)]\n",
    "c = np.random.rand(40).reshape(4,1,10)\n",
    "map_a = {k:v for v,k in enumerate(np.unique(a))}\n",
    "map_b = {k:v for v,k in enumerate(np.unique(b))}\n",
    "n_cat_embeds = [len(m.keys()) for m in [map_a, map_b]]\n",
    "szs = [emb_sz_rule(n) for n in n_cat_embeds]\n",
    "a = np.asarray(a.map(map_a)).reshape(4,1,10)\n",
    "b = np.asarray(b.map(map_b)).reshape(4,1,10)\n",
    "inp = torch.from_numpy(np.concatenate((c,a,b), 1)).float()\n",
    "feature_extractor = partial(Conv1d, ks=3, padding='same')\n",
    "model = TSiTPlus(3, 2, 10, d_model=64, cat_pos=[1,2], feature_extractor=feature_extractor)\n",
    "test_eq(model(inp).shape, (4,2))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Sequence Embedding"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Sometimes you have samples with a very long sequence length. In those cases you may want to reduce their length before passing them to the transformer. To do that, you can simply pass a `token_size`, as in this example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([8, 128, 168])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "t = torch.rand(8, 2, 10080)\n",
    "SeqTokenizer(2, 128, 60)(t).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([8, 5])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "t = torch.rand(8, 2, 10080)\n",
    "model = TSiTPlus(2, 5, 10080, d_model=64, token_size=60)\n",
    "model(t).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/javascript": "IPython.notebook.save_checkpoint();",
      "text/plain": [
       "<IPython.core.display.Javascript object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/Users/nacho/notebooks/tsai/nbs/068_models.TSiTPlus.ipynb saved at 2025-03-01 15:24:45\n",
      "Correct notebook to script conversion! 😃\n",
      "Saturday 01/03/25 15:24:48 CET\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "                <audio  controls=\"controls\" autoplay=\"autoplay\">\n",
       "                    <source src=\"data:audio/wav;base64,UklGRvQHAABXQVZFZm10IBAAAAABAAEAECcAACBOAAACABAAZGF0YdAHAAAAAPF/iPh/gOoOon6w6ayCoR2ZeyfbjobxK+F2Hs0XjKc5i3DGvzaTlEaraE+zz5uLUl9f46fHpWJdxVSrnfmw8mYEScqUP70cb0Q8X41uysJ1si6Eh1jYzXp9IE2DzOYsftYRyoCY9dJ/8QICgIcEun8D9PmAaBPlfT7lq4MFIlh61tYPiCswIHX+yBaOqT1QbuW7qpVQSv9lu6+xnvRVSlyopAypbGBTUdSalrSTaUBFYpInwUpxOzhti5TOdndyKhCGrdwAfBUcXIJB69p+Vw1egB76+n9q/h6ADglbf4LvnIHfF/981ODThF4m8HiS0riJVjQ6c+/EOZCYQfJrGrhBmPVNMmNArLKhQlkXWYqhbaxXY8ZNHphLuBJsZUEckCTFVHMgNKGJytIDeSUmw4QN4Qx9pReTgb3vYX/TCBuApf75f+P5Y4CRDdN+B+tngk8c8nt03CKGqipgd13OhotwOC5x9MCAknFFcmlmtPmagFFFYOCo0qRzXMhVi57pryNmIEqJlRi8bm52PfuNM8k4dfQv+4cO12l6zCGdg3jl730uE/KAPvS+f0wEAoAsA89/XfXQgBESIn6S5luDtiC8eh/YmIfpLqt1OMp5jXg8/24MveqUNUnPZsqw0Z3yVDldnaUOqIZfXlKrm36zzWhjRhaT+r+ncHI5/otUzfd2uSt7hl/bqXtoHaCC6+mqfrAOeoDD+PJ/xf8RgLMHfH/b8GeBihZIfSXidoQSJWB52NM1iRkzz3MkxpKPbUCrbDu5d5fgTAxkSK3JoEhYD1p2omere2LZTuqYLbdWa49Cx5Dww7tyXDUnioXRkHhwJyKFvd/AfPoYy4Fl7j1/LQorgEr9/X89+0qAOAwAf13sJoL8Gkd8wt25hWIp3Heez/eKODfPcSPCzpFNRDVqf7UlmnNQKGHgqd+jgVvJVm2f265QZTpLS5byur1tpT6ajvrHq3Q2MXWIxtUCehoj8YMk5LB9hRQegeTypn+nBQWA0QHgf7f2q4C5EFt+5ucOg2YfHXtq2SSHpS0ydnTL4IxFO6pvNb4ulBdInWfcsfSc7VMmXpSmE6eeXmZThJxpsgRohEfOk86+AHCoOpOMFsx1dv8s6oYT2k17uR7ngpXod34IEJqAaPfnfyABCIBZBpl/NPI2gTQVjX134x2ExSPMeR7VtYjZMWJ0W8ftjkA/YW1durCWykvjZFKu4p9LVwVbZKNkqpxh6U+6mRC2mGq2Q3SRvsIgcpc2sIpD0Bp4uiiFhW3ecXxOGgaCDe0Vf4cLPoDv+/5/mfw1gN4KKX+17emBqBmYfBHfVYUZKFR44NBtiv41bHJUwx+RJkP1apu2VJlkTwli4qrwoo1ax1dToNCtemRSTBGXz7kJbdM/PY/Dxht0dTLziH7Ul3loJEiE0uJsfdsVTYGL8Yt/AgcMgHYA7X8S+IqAYA+QfjzpxIIVHnp7tdqzhmAstXaxzEqMETpScGC/dJP3Rmdo8LIZnOVSEF+Opxumsl1sVF+dVrE5Z6NIiZSkvVdv2zsqjdnK8HVDLlyHyNjuegogM4NA5z9+YRG9gA722H97AgOA/gSyf43zCIHdE899yuTIg3ciNXpm1jmImTDwdJPITI4RPhRugbvslbFKt2Vfr/6eTFb4W1WkY6m6YPdQjJr2tNZp3EQlko7BgXHRNz2LAc+gdwMq7IUf3R58ohtFgrbr6n7hDFWAlPr8f/T9I4CECU9/De+vgVQY5nxh4POEzybJeCTS5YnCNAZzhsRzkP1Bsmu4t4aYU07nYuerA6KWWcJYO6HHrKJjaE3Zl624UWz/QOOPjcWHc7QzdIk40yl5tCWjhIDhJX0xF4CBMvBsf10IF4Ac//Z/bPlsgAcO
wn6S6n6CwxzUewLcRoYaKzV38M23i9o493CNwL6S1UUuaQe0QpvbUfdfiqglpcRccFU+nkWwambASUiVfLyqbg49xY2eyWh1hy/Sh37XjHpaIYKD7OUEfrgS5IC09MV/1gMBgKMDyH/n9N6AhhINfh7mdoMoIZt6r9fAh1cvfHXNya6N4DzDbqi8K5WWSYlmbbAdnkpV6FxJpWSo1V8DUmGb3rMRaQBG2JJgwN9wCDnNi8HNI3dKK1aG0dvHe/UciIJf6rt+Og5wgDn59X9P/xWAKQhxf2XweYH+FjB9suGVhIMlOnlo02GJhTOdc7vFyo/TQGxs2Li7lz9NwmPurBihnVi7WSWiwKvGYntOpJiOt5drKUKMkFnE8HLxNPmJ9NG4eP8mAYUv4Np8hhi3gdruSX+3CSWAwP38f8f6UoCuDPF+6Os8gnAbKnxQ3d2F0imydzDPKIuiN5lxu8EKkrFE82kftW2az1DbYImpMqTUW3FWIJ83r5hl2koJlla7+m0+PmSOZcjcdMgwS4g11iZ6qCLUg5jkxn0QFA6BWvOvfzEFBIBHAtp/Qfa3gC4RSH5y5yeD2B/8evnYS4cULgR2CMsUja47cG/QvW6UeEhXZ3+xP51GVNVdP6Zpp+1eDFM5nMeySWghR4+TNL85cD46YIyCzKJ2kCzEhoTabXtGHs+CCemJfpMPjoDe9+t/qQALgM8Gj3++8UaBqRV2fQTjO4Q3JKd5r9TgiEYyMHTxxiWPpz8jbfq585YpTJpk960xoKFXsVoTo7yq6GGMTw==\" type=\"audio/wav\" />\n",
       "                    Your browser does not support the audio element.\n",
       "                </audio>\n",
       "              "
      ],
      "text/plain": [
       "<IPython.lib.display.Audio object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "#|eval: false\n",
    "#|hide\n",
    "from tsai.export import get_nb_name; nb_name = get_nb_name(locals())\n",
    "from tsai.imports import create_scripts; create_scripts(nb_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
